import argparse
import time
import hashlib
import json
import logging
import os
import re
import math
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor
from llm_utils import LanguageModel
from openai import OpenAI


def read_jsonl(file_path: str):
    """Load a JSON Lines file and return its records as a list.

    Each non-blank line is decoded independently with ``json.loads``; lines that
    are empty or contain only whitespace are skipped. A malformed line raises
    ``json.JSONDecodeError``.

    Args:
        file_path: Path to a UTF-8 encoded ``.jsonl`` file.

    Returns:
        A list with one decoded object per non-blank line, in file order.
    """
    records = []
    with open(file_path, encoding='utf-8') as handle:
        for raw_line in handle:
            if not raw_line.strip():
                continue
            records.append(json.loads(raw_line))
    return records

def write_json(data, output_file):
    """Serialize ``data`` to ``output_file`` as pretty-printed UTF-8 JSON.

    The file is overwritten. Non-ASCII characters are written as-is rather than
    escaped, and nested structures are indented by four spaces.

    Args:
        data: Any JSON-serializable object.
        output_file: Destination path.
    """
    dump_options = {'ensure_ascii': False, 'indent': 4}
    with open(output_file, mode='w', encoding='utf-8') as sink:
        json.dump(data, sink, **dump_options)

def _extract_optimal_value_with_llm(text: str, max_tries: int = 3) -> str | None:
    if not text:
        return None
        
    prompt = f"""
        Your task is to precisely extract and return exactly one line from the multi-line text provided below. This line must state the final optimization value (e.g., maximum profit, minimum cost, or total objective value).

        ## Core Instructions

        - **Exact Extraction**: The returned content must be a complete, unmodified line as it appears in the original text.  
        - **Single Output**: Your response must contain only the extracted line. Do not add any prefixes, suffixes, explanations, introductory phrases, or extra formatting.  
        - **Keyword Recognition**: Prioritize identifying and extracting the line that contains common optimization keywords such as:  
        - `cost`  
        - `profit`  
        - `objective`  
        - `value`  
        - `revenue`  
        - `optimal value`
        - 'Total'

        Text to analyze:
        ---
        {text}
        """
    
    for attempt in range(max_tries):
        try:
            response = client.chat.completions.create(
                model=MODEL_NAME,
                messages=[{'role': 'user', 'content': prompt}],
                temperature=0.0,
                max_tokens=100
            )
            extracted_text = response.choices[0].message.content.strip()
            if extracted_text:
                logging.debug(f"Successfully extracted optimal value: '{extracted_text}'")
                return extracted_text
        except Exception as e:
            logging.warning(f"LLM call for extraction failed on attempt {attempt + 1}/{max_tries}. Error: {e}")
        if attempt < max_tries - 1:
            time.sleep(0.5)
            
    logging.error(f"Failed to extract optimal value from text: '{text[:100]}...'")
    return None

# An optionally signed number with optional thousands separators, decimal part and
# exponent, e.g. "-1,234.5", ".75", "3e5", "2.5E-3". The exponent group matters
# because solver output often prints values such as "1.5e+06".
_NUMBER_PATTERN = re.compile(r'[-+]?(?:\d[\d,]*)?\.?\d+(?:[eE][-+]?\d+)?')

def _parse_and_compare_numbers(val1_str: str, val2_str: str) -> bool:
    """Compare the last number found in each string for approximate equality.

    Both inputs are converted with ``str``, so numeric references are accepted too.
    Unicode minus signs (U+2212) are normalized to ``-`` and thousands separators
    are removed before parsing. The final numeric token in each string is used,
    and the two values are compared with ``math.isclose`` using
    ``rel_tol=1e-5`` and ``abs_tol=1e-5``.

    Args:
        val1_str: Reference value text (e.g. "12345" or "1.5e6").
        val2_str: Text containing the model's value (e.g. "Total cost: 12,345").

    Returns:
        True if both strings contain a number and the last numbers are close,
        otherwise False. Failures are logged rather than raised.
    """
    try:
        text1 = str(val1_str).replace('\u2212', '-')
        text2 = str(val2_str).replace('\u2212', '-')

        num1_match = _NUMBER_PATTERN.findall(text1)
        num2_match = _NUMBER_PATTERN.findall(text2)

        if not num1_match or not num2_match:
            logging.warning(f"Could not find a number in one of the strings: '{val1_str}' vs '{val2_str}'")
            return False

        num1 = float(num1_match[-1].replace(',', ''))
        num2 = float(num2_match[-1].replace(',', ''))

        is_correct = math.isclose(num1, num2, rel_tol=1e-5, abs_tol=1e-5)
        logging.debug(f"Comparing parsed numbers: {num1} vs {num2}. Result: {is_correct}")
        return is_correct

    except (ValueError, TypeError) as e:
        logging.error(f"Error parsing numbers for comparison: '{val1_str}' vs '{val2_str}'. Error: {e}")
        return False

def check_correctness(correct_answer, solution_output: str) -> bool:
    """Return True when the optimal value in ``solution_output`` matches ``correct_answer``.

    The model's value line is located with ``_extract_optimal_value_with_llm`` (an LLM
    call). It is then compared numerically against the reference with
    ``_parse_and_compare_numbers``.

    Args:
        correct_answer: Reference answer, either a string or a number (converted with
            ``str``). ``None`` or an empty string counts as missing.
        solution_output: Text produced by the model. ``None`` or an empty value yields False.

    Returns:
        True only if both values are available and compare as equal, otherwise False.
    """
    if not solution_output:
        return False

    # Coerce numeric references to text. A falsy number such as 0 must not be mistaken
    # for a missing answer, and the downstream regex parser works on strings.
    correct_value_str = "" if correct_answer is None else str(correct_answer)
    if not correct_value_str:
        # Skip the LLM extraction call when there is nothing to compare against.
        logging.warning("No reference answer provided. Cannot compare.")
        return False

    logging.info("Step 1: Extracting optimal values...")
    model_value_str = _extract_optimal_value_with_llm(solution_output)
    if not model_value_str:
        logging.warning("Failed to extract one or both optimal values. Cannot compare.")
        return False

    logging.info("Step 2: Parsing and comparing extracted values...")
    return _parse_and_compare_numbers(correct_value_str, model_value_str)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Generate solutions with a local LLM and score them against reference answers.'
    )
    parser.add_argument('--model_path', type=str, required=True, help='Path to the LLM model directory')
    parser.add_argument('--input_dir', type=str, default='', help='Directory holding <dataset>.jsonl input files')
    parser.add_argument('--datasets', type=str, nargs='+', default=['IndustryOR'],
                        help='Dataset names to process (input file stems)')
    parser.add_argument('--output_root', type=str, default='results_one',
                        help='Root directory for per-dataset result files')
    parser.add_argument('--model_name', type=str, default='',
                        help='Model name sent to the answer-extraction endpoint')
    parser.add_argument('--api_endpoint', type=str, default='http://127.0.0.1:8000/v1/',
                        help='OpenAI-compatible base URL used for answer extraction')
    parser.add_argument('--api_key', type=str, default='FAKE_API_KEY',
                        help='API key for the answer-extraction endpoint')
    parser.add_argument('--temperature', type=float, default=0.0, help='temperature passed to LanguageModel')
    parser.add_argument('--top_p', type=float, default=0.9, help='top_p passed to LanguageModel')
    parser.add_argument('--n', type=int, default=1,
                        help='Passed to generate_full_or_solutions; presumably the number of solutions per question')
    parser.add_argument('--tensor_parallel_size', type=int, default=4,
                        help='tensor_parallel_size passed to LanguageModel')
    parser.add_argument('--max_workers', type=int, default=10,
                        help='Threads used for concurrent correctness checks')
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

    # `client` and `MODEL_NAME` must remain module-level names:
    # _extract_optimal_value_with_llm reads them as globals.
    MODEL_NAME = args.model_name
    client = OpenAI(base_url=args.api_endpoint, api_key=args.api_key)

    model_path = args.model_path
    # normpath drops a trailing separator so basename() does not return "" for "models/foo/".
    model_tag = os.path.basename(os.path.normpath(model_path))
    n = args.n

    LLM = LanguageModel(
        model_path=model_path,
        temperature=args.temperature,
        top_p=args.top_p,
        tensor_parallel_size=args.tensor_parallel_size,
    )

    # A single thread pool serves the whole run and is shut down when the `with` exits,
    # instead of a new, never-shut-down pool being created for every question.
    with ThreadPoolExecutor(max_workers=args.max_workers) as executor:
        for dataset_name in tqdm(args.datasets, desc="All Datasets"):
            input_file = os.path.join(args.input_dir, f"{dataset_name}.jsonl")

            if not os.path.exists(input_file):
                logging.warning(f"Input file not found for dataset '{dataset_name}', skipping: {input_file}")
                continue

            output_dir = os.path.join(args.output_root, dataset_name)
            os.makedirs(output_dir, exist_ok=True)
            output_file = os.path.join(output_dir, f"{model_tag}_results.json")

            logging.info(f"Processing dataset: {dataset_name}")
            data = read_jsonl(input_file)

            for d in tqdm(data, desc=f"Processing {dataset_name}", leave=False):
                question = d['question']
                correct_answer = str(d['answer'])

                outputs = LLM.generate_full_or_solutions(question, n)
                # The last element of each output is what gets scored
                # (assumed to be the final answer text; confirm against LanguageModel).
                outputs_to_check = [output[-1] for output in outputs]

                # One reference per output, so executor.map (which stops at the shortest
                # iterable) never silently drops outputs when len(outputs) != n.
                correctnesses = list(executor.map(
                    check_correctness,
                    [correct_answer] * len(outputs_to_check),
                    outputs_to_check,
                ))

                d['correctnesses'] = correctnesses
                d['outputs'] = outputs
                d['code_output'] = outputs_to_check

            write_json(data, output_file)
            logging.info(f"Processing for dataset '{dataset_name}' complete. Results saved to {output_file}")
    
    
































